Proving Convexity of Mean Squared Error Loss in a Regression Setting.
Proving Convexity of Mean Squared Error Loss - Case Study
- In this blog post, we shall quickly cover the convexity proof for the Mean Squared Error (MSE) loss function used in a traditional regression setting.
- In case you haven't checked out my previous blog, The Curious Case of Convex Functions, I highly recommend reading it first. It covers the basic building blocks for proving convexity.
With that in mind, let us start by reviewing -
- The MSE loss for a Regression Algorithm.
- Conditions for checking Convexity.
1. MSE Loss Function -
The MSE loss function in a Regression setting is defined as -
$$ \begin{align} J(w) = \frac{1}{2m}\sum_{i=1}^{m} [y^{(i)} - \hat{y}^{(i)}]^2 \tag{1} \end{align} $$
Where,
$J(w)$ = loss as a function of the regression coefficients $w$.
$m$ = number of training examples.
$y^{(i)}$ = true value for the $i$-th training example.
$\hat{y}^{(i)}$ = predicted value for the $i$-th training example.
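As a quick illustration of eq. (1), here is a minimal sketch in plain Python. The function name `mse_loss` and the sample values are hypothetical and chosen only for this example; the predictions are passed in directly rather than computed from a model.

```python
def mse_loss(y_true, y_pred):
    """Eq. (1): J = 1/(2m) * sum over i of (y_i - y_hat_i)^2."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    m = len(y_true)
    return sum((y - y_hat) ** 2 for y, y_hat in zip(y_true, y_pred)) / (2 * m)


# Hypothetical true values and predictions for m = 3 training examples.
y_true = [3.0, 5.0, 7.0]
y_pred = [2.0, 5.0, 9.0]
loss = mse_loss(y_true, y_pred)
print(loss)
```

The factor of 1/2 in eq. (1) is kept here so that the code matches the definition used in the rest of the post.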
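For the $i$-th training example, $\hat{y}^{(i)}$ is defined as -
$$ \begin{align} \hat{y}^{(i)} = \sum_{j = 1}^{n}(w_jx_{j}^{(i)} ) \tag{2} \end{align} $$
Where,
$w_j$ = regression coefficient for the $j$-th feature.
$x_{j}^{(i)}$ = value of the $j$-th feature for the $i$-th training example.
$n$ = number of features.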
For readability, let's assume $n = 3$. Eq. (2) can then be written as -
$$ \begin{align} \hat{y}^{(i)} &= \sum_{j = 1}^{n}(w_jx_{j}^{(i)} ) \\ & = w_1x_{1}^{(i)} + w_2x_{2}^{(i)} + w_3x_{3}^{(i)} \tag{3} \end{align} $$
To simplify further, let us consider a single training example ($m = 1$), so that we can drop the training index $i$. The loss in eq. (1) is an average of $m$ such per-example terms, and a non-negative weighted sum of convex functions is convex, so it is enough to prove convexity for one example.
$$ \begin{align} \therefore J(w) = \frac{1}{2} [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})]^2 \tag{4} \end{align} $$
2. Checking for Convexity of J(w) -
To check the convexity of the Mean Squared Error function, we shall perform the following steps -
- Step 1 - Computing the Hessian of J(w).
- Step 2 - Computing the principal minors of the Hessian.
- Step 3 - Based on the values of the principal minors, determining the definiteness of the Hessian.
- Step 4 - Commenting on convexity based on the convexity tests.
Let us get down to it right away-
Step 1 - Hessian of $J(w)$ -
$$ \begin{align} J^H = \begin{bmatrix} \frac{\partial ^2 J}{\partial w_1^2} & \frac{\partial^2 J}{\partial w_1 \partial w_2} & \frac{\partial^2 J}{\partial w_1 \partial w_3} \\ \frac{\partial ^2 J}{\partial w_2 \partial w_1} & \frac{\partial^2 J}{\partial {w_2}^2 } & \frac{\partial^2 J}{\partial w_2 \partial w_3} \\ \frac{\partial ^2 J}{\partial w_3 \partial w_1} & \frac{\partial^2 J}{\partial w_3 \partial w_2 } & \frac{\partial^2 J}{\partial {w_3}^2 } \\ \end{bmatrix} \end{align} $$
Let's compute each component of the matrix.
$$ \begin{align} \frac{\partial ^2 J}{\partial w_1^2} &= \frac{\partial}{\partial w_1} \big[ \frac{\partial}{\partial w_1}\big[\frac{1}{2} [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})]^2\big] \big] \\ &= \frac{\partial}{\partial w_1} \big[ [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})](-x_1) \big] \\ &= (-x_1)(-x_1) \\ &= x_1^2 \end{align} $$
$$ \begin{align} \frac{\partial ^2 J}{\partial w_1 \partial w_2} &= \frac{\partial}{\partial w_1} \big[ \frac{\partial}{\partial w_2}\big[\frac{1}{2} [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})]^2\big] \big] \\ &= \frac{\partial}{\partial w_1} \big[ [y - (w_1x_{1} + w_2x_{2} + w_3x_{3})](-x_2) \big] \\ & = (-x_2)(-x_1) \\ & = x_1x_2 \\ &= \frac{\partial ^2 J}{\partial w_2 \partial w_1} \end{align} $$
Similarly, it can be shown that -
$$ \begin{align} \frac{\partial ^2 J}{\partial w_1 \partial w_3} = \frac{\partial ^2 J}{\partial w_3 \partial w_1} = x_1x_3 \\ \frac{\partial ^2 J}{\partial w_2 \partial w_3} = \frac{\partial ^2 J}{\partial w_3 \partial w_2} = x_2x_3 \end{align} $$
$$ \begin{align} \frac{\partial ^2 J}{\partial w_2^2} = x_2^2 \\ \frac{\partial ^2 J}{\partial w_3^2} = x_3^2 \end{align} $$
$$ \begin{align} \therefore J^H = \begin{bmatrix} x_1^2 & x_1x_2 & x_1x_3 \\ x_2x_1 & x_2^2 & x_2x_3 \\ x_3x_1 & x_3x_2 & x_3^2 \\ \end{bmatrix} \end{align} $$
Step 2 - Computing the Principal Minors -
From the previous blog post, a twice-differentiable function is convex if all the principal minors of its Hessian are greater than or equal to zero, i.e. $\bigtriangleup_k \geq 0 \;\; \forall k$.
Compute $\bigtriangleup_1$ -
Principal minors of order 1 ($\bigtriangleup_1$) can be obtained by deleting any 3-1 = 2 rows and the corresponding columns.
a. By deleting rows 2 and 3 along with the corresponding columns, $\bigtriangleup_1 = x_1^2$
b. By deleting rows 1 and 3 along with the corresponding columns, $\bigtriangleup_1 = x_2^2$
c. By deleting rows 1 and 2 along with the corresponding columns, $\bigtriangleup_1 = x_3^2$
Compute $\bigtriangleup_2$ -
Principal minors of order 2 can be obtained by deleting any 3-2 = 1 row and the corresponding column.
a. By deleting row 1 and corresponding column 1 -
$$ \begin{align} \bigtriangleup_2 & = \begin{vmatrix} x_2^2 & x_2x_3 \\ x_3x_2 & x_3^2 \end{vmatrix} \\ & = x_2^2x_3^2 - (x_2x_3)(x_3x_2) \\ & = x_2^2x_3^2 - x_2^2x_3^2 \\ & = 0 \end{align} $$
b. By deleting row 2 and corresponding column 2 -
$$ \begin{align} \bigtriangleup_2 & = \begin{vmatrix} x_1^2 & x_1x_3 \\ x_3x_1 & x_3^2 \end{vmatrix} \\ & = x_1^2x_3^2 - (x_1x_3)(x_3x_1) \\ & = x_1^2x_3^2 - x_1^2x_3^2 \\ & = 0 \end{align} $$
c. By deleting row 3 and corresponding column 3 -
$$ \begin{align} \bigtriangleup_2 & = \begin{vmatrix} x_1^2 & x_1x_2 \\ x_2x_1 & x_2^2 \end{vmatrix} \\ & = x_1^2x_2^2 - (x_1x_2)(x_2x_1) \\ & = x_1^2x_2^2 - x_1^2x_2^2 \\ & = 0 \end{align} $$
Compute $\bigtriangleup_3$ -
The principal minor of order 3 is the determinant of the Hessian $J^H$ itself.
$$ \begin{align} \bigtriangleup_3 & = \begin{vmatrix}J^H \end{vmatrix} \\ &= \begin{vmatrix} x_1^2 & x_1x_2 & x_1x_3 \\ x_2x_1 & x_2^2 & x_2x_3 \\ x_3x_1 & x_3x_2 & x_3^2 \\ \end{vmatrix}\\ &= x_1^2 \cdot (x_2^2x_3^2 - x_2^2x_3^2) - x_1x_2 \cdot (x_1x_2x_3^2 - x_1x_2x_3^2) + x_1x_3 \cdot (x_1x_2^2x_3 - x_1x_2^2x_3) \\ &= 0 \end{align} $$
Step 3 - Comment on Definiteness of Hessian of J(w) -
- The principal minors of order 1 are squares of real numbers. A squared real number is always non-negative (it is zero when the corresponding feature value is zero).
- The principal minors of order 2 and 3 are equal to zero.
- It can be concluded that $\bigtriangleup_k \geq 0 \;\; \forall k$
- Hence the Hessian of J(w) is Positive Semidefinite.
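The conclusion above can be sanity-checked numerically. The sketch below builds the Hessian derived in Step 1 for one hypothetical feature vector, enumerates every principal minor by keeping the same subset of rows and columns, and checks that none is negative. The helper names (`hessian_single_example`, `determinant`, `principal_minors`) and the sample values are assumptions made for this illustration, and a check at one point is not a substitute for the symbolic argument.

```python
from itertools import combinations


def hessian_single_example(x):
    """Hessian of eq. (4): entry (j, k) is x_j * x_k."""
    return [[xj * xk for xk in x] for xj in x]


def determinant(matrix):
    """Determinant by cofactor expansion along the first row (fine for small matrices)."""
    n = len(matrix)
    if n == 1:
        return matrix[0][0]
    total = 0.0
    for col in range(n):
        # Minor: drop row 0 and the current column.
        minor = [row[:col] + row[col + 1:] for row in matrix[1:]]
        total += (-1) ** col * matrix[0][col] * determinant(minor)
    return total


def principal_minors(matrix):
    """Map each order k to the list of all order-k principal minors."""
    n = len(matrix)
    minors = {}
    for k in range(1, n + 1):
        minors[k] = []
        for keep in combinations(range(n), k):
            # Keep the same index set for rows and columns.
            sub = [[matrix[r][c] for c in keep] for r in keep]
            minors[k].append(determinant(sub))
    return minors


x = [2.0, -1.0, 3.0]  # hypothetical feature values x_1, x_2, x_3
H = hessian_single_example(x)
minors = principal_minors(H)
# Small tolerance guards against floating-point round-off.
is_psd = all(d >= -1e-9 for ds in minors.values() for d in ds)
print(minors)
print("all principal minors non-negative:", is_psd)
```

Cofactor expansion is used here only because the matrix is 3 x 3; for larger matrices a library routine would be the more practical choice.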
Step 4 - Comment on Convexity -
Before we comment on the convexity of J(w), let's revise the conditions for convexity -
If $X^H$ is the Hessian matrix of a twice-differentiable function f(x) on a convex set $S$, then -
- f(x) is strictly convex in $S$ if $X^H$ is a Positive Definite matrix.
- f(x) is convex in $S$ if $X^H$ is a Positive Semidefinite matrix.
- f(x) is strictly concave in $S$ if $X^H$ is a Negative Definite matrix.
- f(x) is concave in $S$ if $X^H$ is a Negative Semidefinite matrix.
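Since the Hessian of J(w) is Positive Semidefinite, it can be concluded that the function J(w) is convex. Because the Hessian is only semidefinite (its determinant is zero), this test does not establish strict convexity for the single-example loss.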
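As an independent sanity check of this conclusion, the definition of convexity can be tested directly on the full loss from eq. (1): for any two weight vectors $a$ and $b$ and any $t \in [0, 1]$, a convex $J$ must satisfy $J(ta + (1-t)b) \leq tJ(a) + (1-t)J(b)$. The sketch below samples random points and checks this inequality. The data, function names, sampling range and tolerance are all hypothetical choices for illustration, and passing such a test is only supporting evidence, not a proof.

```python
import random


def predict(w, x):
    """Eq. (2): linear prediction for one training example."""
    return sum(wj * xj for wj, xj in zip(w, x))


def mse_loss(w, X, y):
    """Eq. (1), using the predictions from eq. (2)."""
    m = len(y)
    return sum((yi - predict(w, xi)) ** 2 for xi, yi in zip(X, y)) / (2 * m)


def satisfies_convexity_inequality(X, y, trials=1000, seed=0, tol=1e-9):
    """Check J(t*a + (1-t)*b) <= t*J(a) + (1-t)*J(b) at random points."""
    rng = random.Random(seed)
    n = len(X[0])
    for _ in range(trials):
        a = [rng.uniform(-5, 5) for _ in range(n)]
        b = [rng.uniform(-5, 5) for _ in range(n)]
        t = rng.random()
        mix = [t * ai + (1 - t) * bi for ai, bi in zip(a, b)]
        lhs = mse_loss(mix, X, y)
        rhs = t * mse_loss(a, X, y) + (1 - t) * mse_loss(b, X, y)
        if lhs > rhs + tol:  # tolerance absorbs floating-point round-off
            return False
    return True


# Hypothetical data: m = 4 training examples, n = 3 features.
X = [[1.0, 2.0, 0.5], [0.0, -1.0, 3.0], [2.0, 2.0, 2.0], [-1.0, 0.5, 1.0]]
y = [1.0, 0.0, 2.5, -0.5]
convex_on_samples = satisfies_convexity_inequality(X, y)
print("inequality held on all sampled points:", convex_on_samples)
```

This check uses all $m$ examples rather than the single-example simplification, so it also exercises the claim that averaging the per-example losses preserves convexity.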
Final Comments -
- This blog post is aimed at proving the convexity of MSE loss function in a Regression setting by simplifying the problem.
- There are different ways of proving the convexity but I found this easier to comprehend.
- Feel free to try out the process for different loss functions that you may have encountered.